{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow2教程-DCGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "import tensorflow as tf\n",
    "import glob\n",
    "import imageio\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import PIL\n",
    "import tensorflow.keras.layers as layers\n",
    "import time\n",
    "\n",
    "from IPython import display"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.数据导入和预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()\n",
    "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n",
    "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "BUFFER_SIZE = 60000\n",
    "BATCH_SIZE = 256\n",
    "# Batch and shuffle the data\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.构建模型\n",
    "\n",
    "### 构建生成器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_generator_model():\n",
    "    model = tf.keras.Sequential()\n",
    "    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))\n",
    "    model.add(layers.BatchNormalization())\n",
    "    model.add(layers.LeakyReLU())\n",
    "      \n",
    "    model.add(layers.Reshape((7, 7, 256)))\n",
    "    assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size\n",
    "    \n",
    "    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))\n",
    "    assert model.output_shape == (None, 7, 7, 128)  \n",
    "    model.add(layers.BatchNormalization())\n",
    "    model.add(layers.LeakyReLU())\n",
    "\n",
    "    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))\n",
    "    assert model.output_shape == (None, 14, 14, 64)    \n",
    "    model.add(layers.BatchNormalization())\n",
    "    model.add(layers.LeakyReLU())\n",
    "\n",
    "    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))\n",
    "    assert model.output_shape == (None, 28, 28, 1)\n",
    "  \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "生成器生成图片"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fab044e90b8>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGHRJREFUeJzt3XuM1eWZB/DvA8Md5OLUYUAEpGgFrUgHMGitF2iFqmg1VNq0LLHSNG2yNLXZBtOuxjRtt2uNtpumsJCCZQu2omJLXRBpgXIRRBy8AiLKHbkP9xnm2T/msJkqv+8zzgznHPt+Pwlh5nznPeedH/Nwzpz3Zu4OEUlPi0J3QEQKQ8UvkigVv0iiVPwiiVLxiyRKxS+SKBW/SKJU/CKJUvGLJKoknw/Wvn1779y5c2ZeW1vb6PuO2rZowf+fO336NM1LShp/qaJZlC1btqR51Dd2/9H3HV236Puurq5uUnumqdeNMTOa19TU0Dy6bk35N43asr4fPHgQx44d499cTpOK38xuBvAogJYA/tvdf8q+vnPnzpgwYUJmfuTIEfp47AepqqqKtu3UqRPNDx48SPPzzz+f5sypU6do3q1bN5ofOnSo0fffrl072vb48eM0j/q2c+dOmpeVlWVmUQFF/7GwJ5JIVGD79u2j+dGjR2netWtXmrOf1+hntXXr1pnZlClTaNv6Gv2y38xaAvgvAKMADAAwzswGNPb+RCS/mvI7/1AAm9x9s7ufAjAbwJjm6ZaInGtNKf6eALbW+3xb7rZ/YGYTzWyNma05duxYEx5ORJrTOX+3392nuHuFu1e0b9/+XD+ciDRQU4p/O4Be9T6/MHebiHwMNKX4VwPob2Z9zaw1gLsBzGuebonIudbooT53rzGz7wD4X9QN9U1399eidmzcuUuXLrRthw4dMrNo2IgNjwBAx44daf7pT3+60Y/91ltv0Tz6vqOhn02bNmVm0RBlNKTF7huIrysbz37llVdo2yFDhtD84osvpvnLL7+cmUVDlMOHD6f5jh07aB69v7V9e/aL5Oj73rVrV2YWzQmpr0nj/O4+H8D8ptyHiBSGpveKJErFL5IoFb9IolT8IolS8YskSsUvkqi8ruevra2l45/RWDsb716/fj1tG40Jt2nThuZbt27NzD71qU/RttGa+m3bttE8Wj565ZVXNvq+oyW9AwbwhZqtWrWiORsPv+GGG2jbaCw+Wip94MCBzGzQoEG07fvvv0/z6N+kd+/eNGfj8dEcAjZ346PscaBnfpFEqfhFEqXiF0mUil8kUSp+kUSp+EUSldehPoBvxxztwFtZWZmZlZeX07bRbqpz586l+a233pqZLViwgLZt27Ytzf/0pz/R/DOf+QzNX3rppcws2uE22hk4GsqLhmfZkFp0zYcOHUrzDRs20Jz1PRpO27t3L82jXY23bNlCc7bMu7S0tEmP3VB65hdJlIpfJFEqfpFEqfhFEqXiF0mUil8kUSp+kUTlfZyfHS/MtuYG+PLSaDyajYUDwFe+8hWab9y4MTMbNmwYbfvCCy/QfOrUqTSPTl5lY+1NGYcH4uu6cuVKmrMlpqNGjaJtozkI0fbYbEv1aOlrNDcj6lt03QcPHpyZRcuJm3KUfX165hdJlIpfJFEqfpFEqfhFEqXiF0mUil8kUSp+kUQ1aZzfzLYAqAJwGkCNu1fQByspoWuR161bRx+PrWuP1oZH22svWbKE5hMnTszM5s/nBxX369eP5rNmzaJ5tMV19+7dM7PomrKjxwF+lDQAtG/fnubnnXdeZvbEE0/QthUV9McpfOz7778/M3vooYdo22jb8AjbtwLg812OHDlC2/bt2zczi45Mr685Jvnc4O585wMRKTp62S+SqKYWvwNYYGYvmVn262IRKTpNfdl/rbtvN7MLACw0szfd/R9+ec79pzARiPeTE5H8adIzv7tvz/29B8BTAD6046K7T3H3CneviBbuiEj+NLr4zayDmXU68zGAzwN4tbk6JiLnVlNe9pcBeCo3ZFEC4H/c/blm6ZWInHONLn533wwg+2zoszAzOg553XXX0fbsWOThw4fTtvv27aN5NA/g8ccfz8yi47+j9dcDBw6k+auv8hdUF1xwQWbGjoIGgN27d9N806ZNNI+u+7Rp0zKzcePG0bbLly+n+Re+8AWas7kZ/fv3p20vu+yyJuV//etfad6jR4/MrLq6mrZ98803M7MTJ07QtvVpqE8kUSp+kUSp+EUSpeIXSZSKXyRRKn6RROV16+5Tp07Ro4uj7ZS7dOmSmbGlowBf9grEW3v36dMnM7vySj7iOW/ePJr37t2b5jU1NTR/+OGHM7OSEv5PHPUtOtr8qquuovmXvvSlzGzXrl20LbvmALB+/Xqas5+JaHvsaMvy6Aju48eP05wtGY6WA7dokf2czZYKf+h+GvyVIvJPRcUvkigVv0iiVPwiiVLxiyRKxS+SKBW/SKLyOs7fsmVLunV3mzZtaHu2tLVdu3a0bZTffvvtNGfzAP7whz/Qtmx+AhDPb4i2qGZLghctWkTb3nrrrTSPlhP/7ne/oznDtmIHgL17+abQnTp1ojm7rtEy7KqqKprv2LGD5idPnmz0/UdzDD7KWD6jZ36RRKn4RRKl4hdJlIpfJFEqfpFEqfhFEqXiF0lUXsf5a2pq6DrqSy65hLZn694XLlxI244aNYrm0Vj7mDFjMrPnnuPHFbz33ns0j45kjvrGsPX0QLx1d3Td2DbSAJ+jMGfOHNp27NixNH/77bdpzvoWbWleXl5O83vvvZfmCxYsoPmGDRsys2hbcTYfJtq/oT4984skSsUvkigVv0iiVPwiiVLxiyRKxS+SKBW/SKLCQUEzmw7gFgB73P3y3G3dAMwB0AfAFgBj3f1AA+6LjllH65jZPuzRePaxY8do3qtXL5qz/e2jsdUDB/iliY4Hj/YD2L9/f2YWzTGI+h6dhxAdkz1z5szMbPLkybRttBdBdF4COxfghhtuoG0rKytp/thjj9E8+nm76KKLaM6wn6do/kJ9DXnm/y2Amz9w2w8ALHL3/gAW5T4XkY+RsPjdfQmADz61jAEwI/fxDAB8GxwRKTqN/Z2/zN3PnDe0C0BZM/VHRPKkyW/4ed3BYpmHi5nZRDNbY2ZrovPLRCR/Glv8u82sHAByf+/J+kJ3n+LuFe5eEW2iKSL509jinwdgfO7j8QCeaZ7uiEi+hMVvZr8HsALApWa2zczuAfBTACPNbCOAEbnPReRjJBznd/dxGdFNH/XBzAytW7fOzNeuXUvbHzp0KDObMGECbRudcT937lyas3Xp0Zhu27ZtaR6tS2fj+AD/3l9++WXalu35DwA/+tGPaP7000/TfPDgwZnZ0qVLadtofsPo0aNpzvaO+OEPf0jb3nfffTSPzpg4ePAgzadNm5aZ3XXXXbTtFVdckZl9lL0fNMNPJFEqfpFEqfhFEqXiF0mUil8kUSp+kUTl/Yjuzp07Z+bRkcxsGeT8+fNp29LSUpofPXqU5idOnMjMou2rhw0bRvMePXrQfNWqVTRnx4ez4S4AuPPOO2keLdl94oknaM6WFF9++eW0bTT82rNnT5pv27YtM3vooYdo2+ho8ui6RsOUP/vZzzKzFStW0LbseO/mXtIrIv+EVPwiiVLxiyRKxS+SKBW/SKJU/CKJUvGLJCqv4/ynT5+mY5TRuC0ba1+8eDFte/vtfI/RaNy2U6dOmVmfPn1oWzbeDMTLP6Nlmuxo85tu4iuvn332WZqfOnWK5n/84x9pftlll2Vmmzdvpm2/9rWv0by6uprmbPn4pk2baNto3keHDh1oHm0r/sgjj2Rm0bHoZtao7IP0zC+SKBW/SKJU/CKJUvGLJErFL5IoFb9IolT8IonK6zg/wLfQjo4tZttQf//736dt2Zp3ABg+fDjNt27dmpn179+ftmXHewPAyJEjaf7zn/+c5uzxozkGI0aMoPmyZctoft1119F89uzZmdmkSZNo2wceeIDm0fHhbH5EbW0tbVt3Cl22aF5ItL/EPffck5lt2LCBtmVbwUf9rk/P/CKJUvGLJErFL5IoFb9IolT8IolS8YskSsUvkqhwnN/MpgO4BcAed788d9sDAO4FcGawc7K784FNAK1atUKvXr0y82hdOzuiOxpX7d69O82jI7yXLFmSmUVjxsePH6f5L3/5S5pHY8psfsTGjRtp25MnT9I8UlLCf4TY2vTx48fTtrfccgvNL730Upq/+OKLmdlPfvIT2vZb3/oWzSNsHwOA780f7aHQu3fvzKxFi4Y/nzfkK38L4Oaz3P6Iuw/K/QkLX0SKS1j87r4EwP489EVE8qgpv/N/x8wqzWy6mXVtth6JSF40tvh/DaAfgEEAdgJ4OOsLzWyima0xszXRvmgikj+NKn533+3up929FsBUAEPJ105x9wp3r4g2PRSR/GlU8ZtZeb1P7wDAjzQVkaLTkKG+3wO4HkCpmW0D8O8ArjezQQAcwBYA3zyHfRSRcyAsfncfd5abpzXmwU6fPo19+/Zl5tH+9507d87MovHN6P2GVatW0bxr1+z3NDt27Ejb9u3bl+alpaU0Z2cGRHl01nuPHj1ofvjwYZp369aN5mxN/Xe/+13advr06TQ/duwYzd9+++3M7Mc//jFty/ZvAIChQzN/0wUQnwvA6uATn/gEbXv69OnMTOv5RSSk4hdJlIpfJFEqfpFEqfhFEqXiF0lUXrfurq2tpctb33nnHdqeDe1EW1RHw2XsOGcAePDBBzOzv/3tb7Tt7t27aR6Jhn5Wr16dmUVbcz/55JM0Hzt2LM2jrb3Zkt9oKfQdd9xBczaUBwCDBg3KzKKhuMGDB9M8Otr8s5/9LM3Z9tvR8CurIQ31iUhIxS+SKBW/SKJU/CKJUvGLJErFL5IoFb9IovJ+RDcbh4zGKNlR1Ndccw1tW1lZSfMuXbrQfPHixZnZFVdcQdtGx4OPGTOG5kuXLqX55z73ucxswYIFtO2QIUNo/ve//53ml1xyCc3ffffdzCw6Fv2FF16gec+ePWnOtoKPtmqP5mZ88YtfpPnKlStpzuadRFu9s6PJ2RLqD9Izv0iiVPwiiVLxiyRKxS+SKBW/SKJU/CKJUvGLJCqv4/wlJSV0m+roiG52THZFRQVtG40p//nPf6Y521Z84cKFtG15eTnNZ82aRXO2VTPAx4XZmnYA2Lx5M82jvkf3f8EFF2Rm0b93VVVVkx6bzTGI5k5E4/zR/Iirr76a5vPmzcvMomPVH3vsscws2u68Pj3ziyRKxS+SKBW/SKJU/CKJUvGLJErFL5IoFb9IoixaQ29mvQDMBFAGwAFMcfdHzawbgDkA+gDYAmCsux9g91VWVuZf/vKXM/Nof/oDB7LvPtp3/8UXX6R5NGa8d+/ezCw6Hpz1GwCeeeYZmt911100v+mmmzKzaDz7G9/4Bs0fffRRmt944400P3nyZGZ24YUX0rbR8eLRPIFLL700Mxs2bBhtO2PGDJpHe+uvXbuW5tdff31mxo7vBvheBHPmzMGePXuM3kFOQ575awB8z90HALgawLfNbACAHwBY5O79ASzKfS4iHxNh8bv7Tndfm/u4CsAbAHoCGAPgzH+PMwDcfq46KSLN7yP9zm9mfQBcBWAVgDJ335mLdqHu1wIR+ZhocPGbWUcATwKY5O6H62de98bBWd88MLOJZrbGzNZEe5OJSP40qPjNrBXqCn+Wu8/N3bzbzMpzeTmAPWdr6+5T3L3C3SvatWvXHH0WkWYQFr+ZGYBpAN5w91/Ui+YBGJ/7eDwA/pa1iBSVhgz1XQtgKYD1AM6cqTwZdb/3PwHgIgDvom6obz+7r9LSUr/ttttYTvvChvOeeuop2nbAgAE0b9OmDc3ZsFLfvn1pWzbkBMTDQh07dqQ528p5/376T4JTp07RPBrSYkt2AWDq1KmZGRvuAvhW7QDw9NNP05wdsx0d6f71r3+9SY8dbWm+fPnyzKy6upq2ZdtzL168GAcOHGjQUF+4nt/dlwHIurPsAWYRKWqa4SeSKBW/SKJU/CKJUvGLJErFL5IoFb9IovK6dXeLFi3oeDrbHhvgY7MPPvggbTtz5kyaR8uJ2Xj4nXfeSdtOmjSJ5tGy2uioajbW/t5779G248aNo3l0vHi0lPruu+/OzAYOHEjbzp49u0mPXVKS/eM9YsQI2nbKlCk0v/nmm2kebed+7bXXZmbRvxmbe7F69Wratj4984skSsUvkigVv0iiVPwiiVLxiyRKxS+SKBW/SKLyfkR39+7dM/NXX32VtmdbGkdbhPXu3Zvm0fprdtzztGnTaNtom+io7x06dKB5v379MrMuXbrQtq+99hrNo/0eou2zjxw5kpl16tSJtmXHWAPxPAA2L+TZZ5+lbd9//32aV1ZW0nzHjh00f+655zKz9u3b07ZDhw7NzFq1akXb1qdnfpFEqfhFEqXiF0mUil8kUSp+kUSp+EUSpeIXSVRex/lra2vpuG+0Tzsbwzx27Bhte/ToUZqfOHGC5ps3b87Mzj//fNp23bp1NI+Oqt6+fTvN2RHdrN8A6LwLAFi5ciXNozFpdv/R9x2ddzB37lyas/ufPHkybfurX/2K5tH8iWi9P7su0RwE9m/KjkT/ID3ziyRKxS+SKBW/SKJU/CKJUvGLJErFL5IoFb9Ioixar21mvQDMBFAGwAFMcfdHzewBAPcCOLPwebK7z2f3VVZW5l/96lcz8127dtG+sLHV6Lz1T37ykzRfsWIFzW+77bbMrKqqirYtLS2l+V/+8heaT5gwgebsunXu3Jm2XbZsGc2jvm/YsIHmbC+CaA+F/fv303z9+vU0HzVqVGYWzV8YPXo0zaOfpzfeeIPmbB+ErVu30rZsjsBvfvMbbN++3egd5DRkkk8NgO+5+1oz6wTgJTM7cyLBI+7+nw15IBEpLmHxu/tOADtzH1eZ2RsAep7rjonIufWRfuc3sz4ArgKwKnfTd8ys0symm1nXjDYTzWyNma2JtqsSkfxpcPGbWUcATwKY5O6HAfwaQD8Ag1D3yuDhs7Vz9ynuXuHuFe3atWuGLotIc2hQ8ZtZK9QV/ix3nwsA7r7b3U+7ey2AqQCydxUUkaITFr+ZGYBpAN5w91/Uu7283pfdAYBvvSsiRaUhQ33XAlgKYD2A2tzNkwGMQ91LfgewBcA3c28OZoqG+qIlns8//3xmFi2rjY7gjpamlpWVZWavv/46bXvo0CGaR0eTl5eX05xtK/7mm2/Sttdccw3No1/VouXK7HuLtre+6KKLaF5TU0Pzli1bNioDgPPOO4/mvXr1ovl9991H85EjR2Zm0ffFjvBesWIFDh061DxDfe6+DMDZ7oyO6YtIcdMMP5FEqfhFEqXiF0mUil8kUSp+kUSp+EUSldetu80MJSXZD/nOO+/Q9sOHD8/MoqWr0Th+tLS1bdu2mdmNN95I2+7du5fmr7zyCs1bt25Nc7Yt+ZAhQ2jbaNnszp106kbYNzaPpGPHjrTt8uXLaT5ixAiaHz58ODMbOHAgbbtp0yaaV1dX07yiooLm7Fj1aGk7+3mLjrmvT8/8IolS8YskSsUvkigVv0iiVPwiiVLxiyRKxS+SqHA9f7M+mNn7AOovPi8FwAfBC6dY+1as/QLUt8Zqzr71dne+eUVOXov/Qw9utsbd+WyIAinWvhVrvwD1rbEK1Te97BdJlIpfJFGFLv4pBX58plj7Vqz9AtS3xipI3wr6O7+IFE6hn/lFpEAKUvxmdrOZvWVmm8zsB4XoQxYz22Jm681snZmtKXBfppvZHjN7td5t3cxsoZltzP191mPSCtS3B8xse+7arTMzftTtuetbLzNbbGavm9lrZvavudsLeu1Ivwpy3fL+st/MWgLYAGAkgG0AVgMY5+588/s8MbMtACrcveBjwmZ2HYAjAGa6++W52/4DwH53/2nuP86u7v5vRdK3BwAcKfTJzbkDZcrrnywN4HYA/4ICXjvSr7EowHUrxDP/UACb3H2zu58CMBvAmAL0o+i5+xIAH9xtYwyAGbmPZ6DuhyfvMvpWFNx9p7uvzX1cBeDMydIFvXakXwVRiOLvCWBrvc+3obiO/HYAC8zsJTObWOjOnEVZvZORdgHIPkqoMMKTm/PpAydLF821a8yJ181Nb/h92LXuPhjAKADfzr28LUpe9ztbMQ3XNOjk5nw5y8nS/6+Q166xJ143t0IU/3YA9Q86uzB3W1Fw9+25v/cAeArFd/rw7jOHpOb+3lPg/vy/Yjq5+WwnS6MIrl0xnXhdiOJfDaC/mfU1s9YA7gYwrwD9+BAz65B7IwZm1gHA51F8pw/PAzA+9/F4AM8UsC//oFhObs46WRoFvnZFd+K1u+f9D4DRqHvH/20A9xeiDxn9uhjAK7k/rxW6bwB+j7qXgdWoe2/kHgDnA1gEYCOA5wF0K6K+PY6605wrUVdo5QXq27Woe0lfCWBd7s/oQl870q+CXDfN8BNJlN7wE0mUil8kUSp+kUSp+EUSpeIXSZSKXyRRKn6RRKn4RRL1fwp5YkSSP4AUAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "generator = make_generator_model()\n",
    "\n",
    "noise = tf.random.normal([1, 100])\n",
    "generated_image = generator(noise, training=False)\n",
    "\n",
    "plt.imshow(generated_image[0, :, :, 0], cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构造判别器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_discriminator_model():\n",
    "    model = tf.keras.Sequential()\n",
    "    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', \n",
    "                                     input_shape=[28, 28, 1]))\n",
    "    model.add(layers.LeakyReLU())\n",
    "    model.add(layers.Dropout(0.3))\n",
    "      \n",
    "    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))\n",
    "    model.add(layers.LeakyReLU())\n",
    "    model.add(layers.Dropout(0.3))\n",
    "       \n",
    "    model.add(layers.Flatten())\n",
    "    model.add(layers.Dense(1))\n",
    "     \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "判别器判别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([[-0.00016926]], shape=(1, 1), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "discriminator = make_discriminator_model()\n",
    "decision = discriminator(generated_image)\n",
    "print (decision)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.定义损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This method returns a helper function to compute cross entropy loss\n",
    "cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 判别器损失\n",
    "def discriminator_loss(real_output, fake_output):\n",
    "    real_loss = cross_entropy(tf.ones_like(real_output), real_output)\n",
    "    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)\n",
    "    total_loss = real_loss + fake_loss\n",
    "    return total_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 生成器损失\n",
    "def generator_loss(fake_output):\n",
    "    return cross_entropy(tf.ones_like(fake_output), fake_output)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator_optimizer = tf.keras.optimizers.Adam(1e-4)\n",
    "discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "checkpoint保持"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_dir = './training_checkpoints'\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
    "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n",
    "                                 discriminator_optimizer=discriminator_optimizer,\n",
    "                                 generator=generator,\n",
    "                                 discriminator=discriminator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.训练函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 50\n",
    "noise_dim = 100\n",
    "num_examples_to_generate = 16\n",
    "\n",
    "# We will reuse this seed overtime (so it's easier)\n",
    "# to visualize progress in the animated GIF)\n",
    "seed = tf.random.normal([num_examples_to_generate, noise_dim])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练迭代函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notice the use of `tf.function`\n",
    "# This annotation causes the function to be \"compiled\".\n",
    "@tf.function\n",
    "def train_step(images):\n",
    "    noise = tf.random.normal([BATCH_SIZE, noise_dim])\n",
    "\n",
    "    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
    "      generated_images = generator(noise, training=True)\n",
    "\n",
    "      real_output = discriminator(images, training=True)\n",
    "      fake_output = discriminator(generated_images, training=True)\n",
    "\n",
    "      gen_loss = generator_loss(fake_output)\n",
    "      disc_loss = discriminator_loss(real_output, fake_output)\n",
    "\n",
    "    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)\n",
    "    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)\n",
    "\n",
    "    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))\n",
    "    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataset, epochs):  \n",
    "  for epoch in range(epochs):\n",
    "    start = time.time()\n",
    "    \n",
    "    for image_batch in dataset:\n",
    "      train_step(image_batch)\n",
    "\n",
    "    # Produce images for the GIF as we go\n",
    "    display.clear_output(wait=True)\n",
    "    generate_and_save_images(generator,\n",
    "                             epoch + 1,\n",
    "                             seed)\n",
    "    \n",
    "    # Save the model every 15 epochs\n",
    "    if (epoch + 1) % 15 == 0:\n",
    "      checkpoint.save(file_prefix = checkpoint_prefix)\n",
    "    \n",
    "    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))\n",
    "    \n",
    "  # Generate after the final epoch\n",
    "  display.clear_output(wait=True)\n",
    "  generate_and_save_images(generator,\n",
    "                           epochs,\n",
    "                           seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "生成和保存图像"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_and_save_images(model, epoch, test_input):\n",
    "  # Notice `training` is set to False. \n",
    "  # This is so all layers run in inference mode (batchnorm).\n",
    "  predictions = model(test_input, training=False)\n",
    "\n",
    "  fig = plt.figure(figsize=(4,4))\n",
    "  \n",
    "  for i in range(predictions.shape[0]):\n",
    "      plt.subplot(4, 4, i+1)\n",
    "      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n",
    "      plt.axis('off')\n",
    "        \n",
    "  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "train(train_dataset, EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 生成一张动图\n",
    "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))\n",
    "def display_image(epoch_no):\n",
    "  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display_image(EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 6.训练过程的动图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n",
    "  filenames = glob.glob('image*.png')\n",
    "  filenames = sorted(filenames)\n",
    "  last = -1\n",
    "  for i,filename in enumerate(filenames):\n",
    "    frame = 2*(i**0.5)\n",
    "    if round(frame) > round(last):\n",
    "      last = frame\n",
    "    else:\n",
    "      continue\n",
    "    image = imageio.imread(filename)\n",
    "    writer.append_data(image)\n",
    "  image = imageio.imread(filename)\n",
    "  writer.append_data(image)\n",
    "    \n",
    "# A hack to display the GIF inside this notebook\n",
    "os.rename('dcgan.gif', 'dcgan.gif.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display.Image(filename=\"dcgan.gif.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
